# encoding=utf-8


def report(test_labels, test_predict, lables, cmp='Blues'):
    """Print a classification report and plot a row-normalised confusion matrix.

    Args:
        test_labels: Ground-truth class labels.
        test_predict: Predicted class labels, aligned with ``test_labels``.
        lables: Names for the heatmap's rows and columns. If every label that
            occurs in ``test_labels`` / ``test_predict`` is contained in
            ``lables``, they are treated as the label values themselves and
            fix the row/column order of the matrix. Otherwise they are treated
            as display names for the classes in sorted-label order, which is
            the order sklearn uses by default.
        cmp: Matplotlib colormap (name or object) for the heatmap.
            Defaults to ``'Blues'``.

    Each row of the plotted matrix is divided by its total, so cell (i, j) is
    the fraction of samples of true class i that were predicted as class j.
    A row whose class has no samples in ``test_labels`` is 0/0, which pandas
    turns into NaN; seaborn leaves such cells blank.

    Raises:
        ValueError: If ``lables`` does not have one entry per matrix row.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from pandas import DataFrame
    from sklearn.metrics import classification_report, confusion_matrix

    # NOTE: plt.rc changes matplotlib's global rcParams, so this font setting
    # also affects figures created after this function returns.
    plt.rc('font', family='Arial', size='18')

    print(classification_report(test_labels, test_predict))

    # Without `labels=`, sklearn orders the matrix by the sorted unique labels,
    # which need not match the order of `lables`. When `lables` are the label
    # values, pass them explicitly so the index/columns below line up with the
    # matrix; otherwise keep sklearn's default (sorted) order.
    observed = set(np.ravel(test_labels)) | set(np.ravel(test_predict))
    if observed <= set(lables):
        counts = confusion_matrix(test_labels, test_predict, labels=lables)
    else:
        counts = confusion_matrix(test_labels, test_predict)

    if len(lables) != counts.shape[0]:
        raise ValueError(
            f"lables has {len(lables)} entries but the confusion matrix "
            f"has {counts.shape[0]} classes"
        )

    # Normalise each row by its total; pandas yields NaN for 0/0 rows.
    df1 = DataFrame(counts, columns=lables, index=lables)
    df1 = df1.div(df1.sum(axis=1), axis=0)
    print('The confusion matrix result:\n', df1)

    # Visualise the normalised matrix as an annotated heatmap.
    plt.figure(figsize=(8, 6))
    sns.heatmap(df1, fmt=".2%", annot=True, cmap=cmp)
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()